{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Model weights\n",
    "- Getting model weights\n",
    "- Saving & loading weights"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 61,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from sklearn.datasets import load_diabetes\n",
    "from sklearn.model_selection import train_test_split\n",
    "from keras.models import Sequential\n",
    "from keras.layers import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(353, 10) (89, 10) (353,) (89,)\n"
     ]
    }
   ],
   "source": [
    "# diabetes dataset: feature matrix and target vector\n",
    "data = load_diabetes()\n",
    "X_data, y_data = data.data, data.target\n",
    "\n",
    "# hold out 20% of the samples for testing; the fixed random_state keeps the split repeatable\n",
    "X_train, X_test, y_train, y_test = train_test_split(\n",
    "    X_data, y_data, test_size=0.2, random_state=7\n",
    ")\n",
    "print(*(split.shape for split in (X_train, X_test, y_train, y_test)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_model():\n",
    "    \"\"\"Build and compile a small fully connected regression network.\n",
    "\n",
    "    Layers (all Dense):\n",
    "        dense_1: X_train.shape[1] inputs -> 10 units, relu\n",
    "        dense_2: 10 -> 15 units, no activation argument (linear)\n",
    "        output:  15 -> 1 unit, linear\n",
    "\n",
    "    The input width is read from the notebook-level X_train, so the\n",
    "    data-loading cell has to be run before this function is called.\n",
    "\n",
    "    Returns:\n",
    "        The compiled Sequential model.\n",
    "    \"\"\"\n",
    "    n_features = X_train.shape[1]\n",
    "\n",
    "    model = Sequential()\n",
    "    model.add(Dense(10, input_shape=(n_features,), activation=\"relu\", name=\"dense_1\"))\n",
    "    model.add(Dense(15, name=\"dense_2\"))\n",
    "    # The diabetes target is a continuous value, so the output unit is linear and the\n",
    "    # model is compiled with a regression loss and metric. The earlier sigmoid output\n",
    "    # with binary_crossentropy and accuracy is a binary-classification setup.\n",
    "    model.add(Dense(1, activation=\"linear\", name=\"output\"))\n",
    "\n",
    "    model.compile(optimizer=\"adam\", loss=\"mean_squared_error\", metrics=[\"mae\"])\n",
    "    return model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Creating model\n",
    "- Created model has three layers - 2 hidden layers & 1 output layer (i.e., 3 dense layers)\n",
    "    - First hidden layer(```dense_1```): has a 10 X 10 weight matrix and a bias vector of length 10\n",
    "    - Second hidden layer(```dense_2```): has a 10 X 15 weight matrix and a bias vector of length 15\n",
    "    - Output layer(```output```): has a 15 X 1 weight matrix and a bias vector of length 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "dense_1 (Dense)              (None, 10)                110       \n",
      "_________________________________________________________________\n",
      "dense_2 (Dense)              (None, 15)                165       \n",
      "_________________________________________________________________\n",
      "output (Dense)               (None, 1)                 16        \n",
      "=================================================================\n",
      "Total params: 291\n",
      "Trainable params: 291\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "# build the network, then list each layer's output shape and parameter count\n",
    "model = create_model()\n",
    "model.summary()"
   ]
  },
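  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Getting model weights\n",
    "- Two ways of getting the weights of an individual layer\n",
    "    - ```model.get_weights()``` returns one flat list ```[W1, b1, W2, b2, ...]```, so for these Dense layers the weight matrix of layer ```i``` is ```model.get_weights()[2*i]``` and its bias is ```model.get_weights()[2*i + 1]```\n",
    "    - ```model.layers[i].get_weights()``` returns the list ```[W, b]``` for layer ```i``` only"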
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<class 'list'>\n",
      "6\n"
     ]
    }
   ],
   "source": [
    "# retrieving model weights: every weight matrix and bias of the model in one list\n",
    "weights = model.get_weights()\n",
    "\n",
    "print(type(weights))    # arrays in a list\n",
    "print(len(weights))     # 3 weight matrices and 3 biases"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(10, 10)\n",
      "(10,)\n",
      "(10, 15)\n",
      "(15,)\n",
      "(15, 1)\n",
      "(1,)\n"
     ]
    }
   ],
   "source": [
    "# one line per array in the weight list: a weight matrix shape followed by its bias shape\n",
    "print(*(arr.shape for arr in weights), sep=\"\\n\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[-0.05131766, -0.4334603 , -0.30013466, -0.13151154, -0.42797315,\n",
       "        -0.20081809,  0.54568088,  0.07199341,  0.45564556, -0.16397029],\n",
       "       [-0.11386764,  0.07849532, -0.32858348,  0.39409363, -0.29487017,\n",
       "        -0.20892417, -0.52551138,  0.06449258, -0.42709106, -0.42912534],\n",
       "       [-0.4862417 , -0.48859712, -0.52028137, -0.13949236, -0.47173452,\n",
       "         0.06096637, -0.2245611 ,  0.53047621,  0.32959503,  0.04725885],\n",
       "       [ 0.51978505,  0.05966097, -0.47718734,  0.1121785 ,  0.30000156,\n",
       "         0.29451877,  0.4363758 ,  0.46372569,  0.53781939,  0.0548085 ],\n",
       "       [-0.46003285, -0.51941347, -0.48344219,  0.14801621, -0.04313284,\n",
       "         0.42180699,  0.33439851,  0.45166206,  0.23053586, -0.51879096],\n",
       "       [ 0.39638752,  0.00477439,  0.14714229, -0.226125  , -0.26046604,\n",
       "         0.24747533, -0.35989189,  0.38416272, -0.5176295 ,  0.15995711],\n",
       "       [-0.44047272, -0.12300456,  0.27150732,  0.20657402, -0.06486583,\n",
       "         0.0667485 ,  0.38686836, -0.13652217,  0.41405624,  0.38533592],\n",
       "       [ 0.33987486, -0.34292638, -0.21142334,  0.17838228,  0.33222139,\n",
       "         0.08218443, -0.00158417,  0.52542567, -0.52481431,  0.18559474],\n",
       "       [ 0.23053557,  0.38431537,  0.22873986,  0.30104119,  0.30194628,\n",
       "        -0.19992581,  0.3789947 , -0.48418537,  0.08003563,  0.15359342],\n",
       "       [ 0.07272196,  0.35086989,  0.49293137,  0.51381564, -0.01203793,\n",
       "         0.42187357,  0.38285059, -0.44306043,  0.49825776,  0.29124063]], dtype=float32)"
      ]
     },
     "execution_count": 46,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# first weight matrix: the first entry of the list returned by model.get_weights()\n",
    "weights[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[array([[-0.05131766, -0.4334603 , -0.30013466, -0.13151154, -0.42797315,\n",
       "         -0.20081809,  0.54568088,  0.07199341,  0.45564556, -0.16397029],\n",
       "        [-0.11386764,  0.07849532, -0.32858348,  0.39409363, -0.29487017,\n",
       "         -0.20892417, -0.52551138,  0.06449258, -0.42709106, -0.42912534],\n",
       "        [-0.4862417 , -0.48859712, -0.52028137, -0.13949236, -0.47173452,\n",
       "          0.06096637, -0.2245611 ,  0.53047621,  0.32959503,  0.04725885],\n",
       "        [ 0.51978505,  0.05966097, -0.47718734,  0.1121785 ,  0.30000156,\n",
       "          0.29451877,  0.4363758 ,  0.46372569,  0.53781939,  0.0548085 ],\n",
       "        [-0.46003285, -0.51941347, -0.48344219,  0.14801621, -0.04313284,\n",
       "          0.42180699,  0.33439851,  0.45166206,  0.23053586, -0.51879096],\n",
       "        [ 0.39638752,  0.00477439,  0.14714229, -0.226125  , -0.26046604,\n",
       "          0.24747533, -0.35989189,  0.38416272, -0.5176295 ,  0.15995711],\n",
       "        [-0.44047272, -0.12300456,  0.27150732,  0.20657402, -0.06486583,\n",
       "          0.0667485 ,  0.38686836, -0.13652217,  0.41405624,  0.38533592],\n",
       "        [ 0.33987486, -0.34292638, -0.21142334,  0.17838228,  0.33222139,\n",
       "          0.08218443, -0.00158417,  0.52542567, -0.52481431,  0.18559474],\n",
       "        [ 0.23053557,  0.38431537,  0.22873986,  0.30104119,  0.30194628,\n",
       "         -0.19992581,  0.3789947 , -0.48418537,  0.08003563,  0.15359342],\n",
       "        [ 0.07272196,  0.35086989,  0.49293137,  0.51381564, -0.01203793,\n",
       "          0.42187357,  0.38285059, -0.44306043,  0.49825776,  0.29124063]], dtype=float32),\n",
       " array([ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.], dtype=float32)]"
      ]
     },
     "execution_count": 47,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# another way to get weights: ask one layer directly, which returns a list holding\n",
    "# that layer's weight matrix followed by its bias\n",
    "model.layers[0].get_weights()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "True\n"
     ]
    }
   ],
   "source": [
    "# they are equal: the first array of the model-level list is the weight matrix\n",
    "# that the first layer reports\n",
    "print(np.array_equal(weights[0], model.layers[0].get_weights()[0]))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Saving & loading weights\n",
    "- ```model.save_weights()```\n",
    "- ```model.load_weights()```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 60,
   "metadata": {},
   "outputs": [],
   "source": [
    "# store the model's weights under the relative path given below\n",
    "# NOTE(review): the on-disk format and any required filename suffix depend on the\n",
    "# installed Keras version - confirm before running on a newer release\n",
    "model.save_weights(\"weights\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 65,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "dense_1 (Dense)              (None, 10)                110       \n",
      "_________________________________________________________________\n",
      "dense_2 (Dense)              (None, 15)                165       \n",
      "_________________________________________________________________\n",
      "output (Dense)               (None, 1)                 16        \n",
      "=================================================================\n",
      "Total params: 291\n",
      "Trainable params: 291\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "# create another model with the same architecture by calling create_model() again\n",
    "another_model = create_model()\n",
    "another_model.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 66,
   "metadata": {},
   "outputs": [],
   "source": [
    "# read the stored weights from the same path into the second model\n",
    "another_model.load_weights(\"weights\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 71,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "True\n",
      "True\n"
     ]
    }
   ],
   "source": [
    "# two models have equal weights: compare the first and the last array of both weight lists\n",
    "for position in (0, -1):\n",
    "    print(np.array_equal(model.get_weights()[position], another_model.get_weights()[position]))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
